package org.zjvis.datascience.common.graph.algo;

import org.zjvis.datascience.common.graph.util.GraphUtil;

import java.util.*;
import java.util.stream.Collectors;

/**
 * Shortest-path statistics for a weighted graph whose edges are supplied as a weight map.
 *
 * <p>Edge weights are looked up per ordered pair {@code (src, tar)}, so the graph is treated as
 * directed unless the caller supplies both directions. All-pairs distances are computed with
 * Floyd-Warshall ({@link #executeFloyd()}); a single source/target shortest path is available via
 * {@link #dijkstra(Object, Object)}. Vertex pairs with no connecting path are ignored by every
 * aggregate statistic.
 *
 * <p>Instances hold mutable, unsynchronized state.
 *
 * @description GraphDistanceWeighted
 * @date 2021-12-29
 */
public class GraphDistanceWeighted {
    /** Sentinel distance meaning "no edge" / "unreachable". */
    private static final double INF = Double.MAX_VALUE;

    /** Vertex ids, in the order of the constructor's {@code idList}. */
    private final Object[] mVexs;
    /** Number of vertices. */
    private final int mVertexNum;
    private final Map<Object, Integer> id2index;
    private final Map<Integer, Object> index2Id;
    /** Direct edge weights (adjacency matrix); {@link #INF} where there is no edge. */
    private final double[][] mMatrix;
    /** Shortest distances; identical to {@link #mMatrix} until {@link #floyd()} runs. */
    private final double[][] mDistance;
    /** {@code mNext[i][j]} is the vertex following i on a shortest i->j path, or -1 if none. */
    private final int[][] mNext;

    private double diameter;
    private double radius;
    private double avgDist;

    private final double[] betweenness;
    private final double[] closeness;
    private final double[] eccentricity;

    private final Map<Object, Double> betweennessMap;
    private final Map<Object, Double> closenessMap;
    private final Map<Object, Double> eccentricityMap;

    /**
     * Builds the adjacency matrix from an edge-weight map.
     *
     * @param map    edge weights keyed by a map holding entries {@code "src"} and {@code "tar"}
     *               (the endpoint ids); the first element of the value list is the weight.
     *               Ordered pairs without an entry have no edge.
     * @param idList vertex ids; their order defines the internal vertex indices
     */
    public GraphDistanceWeighted(Map<Map<Object, Object>, List<Double>> map, List<Object> idList) {
        mVexs = idList.toArray(new Object[0]);
        mVertexNum = mVexs.length;
        id2index = new HashMap<>();
        index2Id = new HashMap<>();
        for (int i = 0; i < mVertexNum; i++) {
            id2index.put(mVexs[i], i);
            index2Id.put(i, mVexs[i]);
        }

        mMatrix = new double[mVertexNum][mVertexNum];
        mDistance = new double[mVertexNum][mVertexNum];
        mNext = new int[mVertexNum][mVertexNum];

        // Primitive arrays are zero-initialized, so every per-vertex statistic starts at 0.
        betweenness = new double[mVertexNum];
        closeness = new double[mVertexNum];
        eccentricity = new double[mVertexNum];

        betweennessMap = new HashMap<>();
        closenessMap = new HashMap<>();
        eccentricityMap = new HashMap<>();

        // One lookup key is reused for every pair. Per the Map contract, map equality is
        // content-based, so it matches the caller's {"src": .., "tar": ..} keys.
        Map<String, Object> edgeKey = new HashMap<>();
        for (int i = 0; i < mVertexNum; i++) {
            for (int j = 0; j < mVertexNum; j++) {
                double dist = INF;
                int next = -1;
                if (i == j) {
                    dist = 0.0;
                } else {
                    edgeKey.put("src", mVexs[i]);
                    edgeKey.put("tar", mVexs[j]);
                    List<Double> weight = map.get(edgeKey);
                    if (weight != null) {
                        dist = weight.get(0);
                        next = j;
                    }
                }
                mMatrix[i][j] = dist;
                mDistance[i][j] = dist;
                mNext[i][j] = next;
            }
        }
    }

    /** Computes all-pairs shortest paths and then every derived statistic. */
    public void executeFloyd() {
        floyd();
        calValue();
    }

    /**
     * Runs Floyd-Warshall in place on {@link #mDistance}, recording path successors in
     * {@link #mNext}. Only strict improvements are applied, so re-running it changes nothing.
     * NOTE(review): negative weights are not rejected; results are meaningless with negative
     * cycles.
     */
    public void floyd() {
        for (int k = 0; k < mVertexNum; k++) {
            double[] fromK = mDistance[k];
            for (int i = 0; i < mVertexNum; i++) {
                double ik = mDistance[i][k];
                if (ik == INF) {
                    // i cannot reach k, so no path through k can improve row i.
                    continue;
                }
                double[] fromI = mDistance[i];
                for (int j = 0; j < mVertexNum; j++) {
                    double kj = fromK[j];
                    if (kj == INF) {
                        continue;
                    }
                    double viaK = ik + kj;
                    if (fromI[j] > viaK) {
                        fromI[j] = viaK;
                        mNext[i][j] = mNext[i][k];
                    }
                }
            }
        }
    }

    /**
     * Reconstructs one shortest path from the successor table built by {@link #floyd()}.
     *
     * @param u source vertex index
     * @param v target vertex index
     * @return vertex indices from u to v inclusive; empty when there is no path or u == v
     */
    public List<Integer> findPath(Integer u, Integer v) {
        List<Integer> path = new ArrayList<>();
        int current = u;
        int target = v;
        if (mNext[current][target] == -1) {
            return path;
        }
        path.add(current);
        while (current != target) {
            current = mNext[current][target];
            path.add(current);
        }
        return path;
    }

    /**
     * Derives distance statistics from {@link #mDistance}; call after {@link #floyd()}.
     *
     * <p>Only ordered pairs (i, j) with i != j and a finite distance contribute:
     * <ul>
     *   <li>diameter: the largest such distance;</li>
     *   <li>average distance: the mean over those pairs;</li>
     *   <li>eccentricity(i): the largest distance from i to a vertex it reaches;</li>
     *   <li>closeness(i): the number of vertices i reaches divided by the sum of those
     *       distances (0 when that sum is 0);</li>
     *   <li>betweenness(v): how many reconstructed shortest paths have v as an intermediate
     *       vertex (one path per ordered pair, ties not split, unnormalized);</li>
     *   <li>radius: the smallest eccentricity among vertices that reach at least one other
     *       vertex (0 if there is none).</li>
     * </ul>
     * The published values are rounded with {@link GraphUtil#round2DecimalDouble}. All
     * statistics are reset first, so calling this repeatedly yields the same results.
     */
    public void calValue() {
        resetStatistics();
        double totalDistance = 0.0;
        long pairCount = 0;
        double minEccentricity = INF;
        for (int i = 0; i < mVertexNum; i++) {
            double distanceSum = 0.0;
            int reachable = 0;
            for (int j = 0; j < mVertexNum; j++) {
                double dist = mDistance[i][j];
                // Skip the vertex itself and unreachable targets. Testing i == j (not dist == 0)
                // keeps pairs joined by zero-weight paths in the statistics.
                if (i == j || dist == INF) {
                    continue;
                }
                totalDistance += dist;
                diameter = Math.max(diameter, dist);
                distanceSum += dist;
                eccentricity[i] = Math.max(eccentricity[i], dist);
                for (int p : findPath(i, j)) {
                    if (p != i && p != j) {
                        betweenness[p] += 1;
                    }
                }
                reachable++;
                pairCount++;
            }
            closeness[i] = distanceSum == 0.0 ? 0.0 : reachable / distanceSum;
            if (reachable > 0) {
                minEccentricity = Math.min(minEccentricity, eccentricity[i]);
            }
            Object id = index2Id.get(i);
            closenessMap.put(id, GraphUtil.round2DecimalDouble(closeness[i]));
            eccentricityMap.put(id, GraphUtil.round2DecimalDouble(eccentricity[i]));
        }
        // Betweenness of vertex i also receives contributions from sources processed after i,
        // so it is only final once every source has been visited.
        for (int i = 0; i < mVertexNum; i++) {
            betweennessMap.put(index2Id.get(i), GraphUtil.round2DecimalDouble(betweenness[i]));
        }
        avgDist = GraphUtil.round2DecimalDouble(pairCount == 0 ? 0.0 : totalDistance / pairCount);
        diameter = GraphUtil.round2DecimalDouble(diameter);
        radius = minEccentricity == INF ? 0.0 : GraphUtil.round2DecimalDouble(minEccentricity);
    }

    /** Clears every accumulated statistic so {@link #calValue()} starts from scratch. */
    private void resetStatistics() {
        diameter = 0.0;
        radius = 0.0;
        avgDist = 0.0;
        Arrays.fill(betweenness, 0.0);
        Arrays.fill(closeness, 0.0);
        Arrays.fill(eccentricity, 0.0);
        betweennessMap.clear();
        closenessMap.clear();
        eccentricityMap.clear();
    }

    /**
     * Finds one shortest path between two vertices with Dijkstra's algorithm on the direct
     * edge weights (independent of {@link #floyd()}).
     *
     * @param srcId id of the start vertex
     * @param tarId id of the end vertex
     * @return vertex ids from srcId to tarId inclusive ({@code [srcId]} when they are equal);
     *         empty when either id is unknown or tarId is unreachable
     */
    public List<Object> dijkstra(Object srcId, Object tarId) {
        Integer source = id2index.get(srcId);
        Integer target = id2index.get(tarId);
        // Unknown ids (which includes every id on an empty graph) have no path.
        if (source == null || target == null) {
            return new ArrayList<>();
        }
        int src = source;
        int tar = target;

        // dist[i]: best known distance src -> i, starting from the direct edge weights.
        double[] dist = mMatrix[src].clone();
        // prev[i]: predecessor of i on the best known path, or -1 while i is unreached.
        int[] prev = new int[mVertexNum];
        for (int i = 0; i < mVertexNum; i++) {
            prev[i] = dist[i] == INF ? -1 : src;
        }
        // settled[i]: i's shortest distance is final.
        boolean[] settled = new boolean[mVertexNum];
        settled[src] = true;

        while (true) {
            int closest = -1;
            double best = INF;
            for (int j = 0; j < mVertexNum; j++) {
                if (!settled[j] && dist[j] < best) {
                    best = dist[j];
                    closest = j;
                }
            }
            if (closest == -1) {
                // Every remaining vertex is unreachable.
                break;
            }
            settled[closest] = true;
            if (closest == tar) {
                // The target's distance and predecessor chain can no longer change.
                break;
            }
            for (int k = 0; k < mVertexNum; k++) {
                double w = mMatrix[closest][k];
                if (!settled[k] && w != INF && dist[closest] + w < dist[k]) {
                    dist[k] = dist[closest] + w;
                    prev[k] = closest;
                }
            }
        }

        if (dist[tar] == INF) {
            return new ArrayList<>();
        }
        List<Integer> indices = new ArrayList<>();
        for (int v = tar; v != src; v = prev[v]) {
            indices.add(v);
        }
        indices.add(src);
        Collections.reverse(indices);
        return indices.stream().map(index2Id::get).collect(Collectors.toList());
    }

    /** @return the rounded largest finite shortest-path distance (0 before {@link #calValue()}) */
    public Double getDiameter() {
        return diameter;
    }

    /** @return the rounded smallest eccentricity among vertices that reach another vertex */
    public Double getRadius() {
        return radius;
    }

    /** @return the rounded mean distance over all reachable ordered pairs of distinct vertices */
    public Double getAvgDist() {
        return avgDist;
    }

    /** @return vertex id to rounded betweenness; empty before {@link #calValue()} */
    public Map<Object, Double> getBetweennessMap() {
        return betweennessMap;
    }

    /** @return vertex id to rounded closeness; empty before {@link #calValue()} */
    public Map<Object, Double> getClosenessMap() {
        return closenessMap;
    }

    /** @return vertex id to rounded eccentricity; empty before {@link #calValue()} */
    public Map<Object, Double> getEccentricityMap() {
        return eccentricityMap;
    }

}
